import pandas as pd
import re
import numpy as np
import functools
from sklearn import preprocessing
from pathlib import Path
from scipy.stats import kendalltau, spearmanr


# The five moral foundations and the two poles each of them is scored on.
foundations = ["care", "fairness", "loyalty", "authority", "sanctity"]
poles = ["virtue", "vice"]
# Codes of the non-English corpora; compared against the corpus'
# ``language_iso`` column when computing multilingual correlations.
languages = ["es", "nl", "de"]
# Grandparent directory of this file (presumably the project root); data paths
# in this module are built relative to it.
main_dir = Path(__file__).resolve().parents[1]
# Every "<foundation>.<pole>" score column in foundation-major order:
# care.virtue, care.vice, fairness.virtue, fairness.vice, ...
list_of_all_column_names = [
    ".".join((foundation, pole)) for foundation in foundations for pole in poles
]


def clean_up_columns(data: pd.DataFrame):
    """Normalise column labels to the ``<foundation>.<pole>`` naming scheme.

    Every column label goes through these substitutions, in order:

    1. ``harm`` -> ``care`` and ``purity`` -> ``sanctity`` (older foundation
       names);
    2. ``<foundation>_<pole>`` -> ``<foundation>.<pole>`` for every known
       foundation/pole pair.

    Labels matching none of the patterns are left untouched. A new DataFrame
    is returned; ``data`` itself is not modified.
    """
    substitutions = [("harm", "care"), ("purity", "sanctity")]
    substitutions += [
        (f"{foundation}_{pole}", f"{foundation}.{pole}")
        for pole in poles
        for foundation in foundations
    ]

    def _normalise(label):
        for old, new in substitutions:
            label = label.replace(old, new)
        return label

    return data.rename(columns={label: _normalise(label) for label in data.columns})


def load_moral_bert(aggregate: bool = True, explicit_manifesto_id: bool = False):
    """Load the MoralBERT sentence scores, optionally averaged per manifesto.

    The raw per-label columns are mapped onto the shared
    ``moralbert_<foundation>.<pole>`` scheme (e.g. ``harm`` becomes
    ``moralbert_care.vice`` and ``purity`` becomes
    ``moralbert_sanctity.virtue``). ``index`` becomes ``id_for_project`` and
    ``text_en`` becomes ``sentence``.

    Args:
        aggregate: When True, attach each sentence's ``manifesto_id`` from the
            corpus metadata, drop ``id_for_project`` and ``sentence``, and
            return the per-manifesto mean of the remaining columns, indexed by
            ``manifesto_id``. When False, return the renamed sentence-level
            frame.
        explicit_manifesto_id: Ignored unless ``aggregate`` is True. Drops an
            ``index`` column if present and turns the ``manifesto_id`` index
            into an ordinary column.

    Returns:
        pd.DataFrame
    """
    column_map = {
        "care": "moralbert_care.virtue",
        "harm": "moralbert_care.vice",
        "fairness": "moralbert_fairness.virtue",
        "cheating": "moralbert_fairness.vice",
        "loyalty": "moralbert_loyalty.virtue",
        "betrayal": "moralbert_loyalty.vice",
        "authority": "moralbert_authority.virtue",
        "subversion": "moralbert_authority.vice",
        "purity": "moralbert_sanctity.virtue",
        "degradation": "moralbert_sanctity.vice",
        "index": "id_for_project",
        "text_en": "sentence",
    }
    scores = pd.read_feather(
        main_dir / "data/moralbert/moralbert_scores.feather"
    ).rename(columns=column_map)
    if not aggregate:
        return scores

    with_manifesto = scores.merge(
        get_meta_data_file(), on="id_for_project", how="left"
    )
    per_manifesto = (
        with_manifesto.drop(["id_for_project", "sentence"], axis=1)
        .groupby("manifesto_id")
        .mean()
    )
    if not explicit_manifesto_id:
        return per_manifesto

    if "index" in per_manifesto.columns:
        per_manifesto = per_manifesto.drop("index", axis=1)
    return per_manifesto.rename_axis("manifesto_id").reset_index()


def load_ddr(
    which: str,
    aggregate: bool = True,
    explicit_manifesto_id: bool = False,
):
    """Load DDR moral-foundation scores.

    Args:
        which: ``"mono"`` reads the English scores from
            ``data/ddr/all_en.feather``. ``"merged"`` inner-joins those
            English scores with the concatenated
            ``data/ddr/sentences_<language>.feather`` files on
            ``id_for_project``. Before the join, English labels containing
            ``ddr_all_en`` are renamed to ``english_anchor`` and per-language
            labels containing ``ddr_original_language`` are renamed to
            ``ddr``.
        aggregate: If True, attach ``manifesto_id`` from the corpus metadata,
            drop ``id_for_project`` and return the per-manifesto mean of the
            remaining columns, indexed by ``manifesto_id``. If False, return
            sentence-level rows.
        explicit_manifesto_id: Only used when aggregating. If True, drop any
            ``index`` column and return ``manifesto_id`` as a regular column.

    Returns:
        pd.DataFrame

    Raises:
        ValueError: If ``which`` is neither ``"mono"`` nor ``"merged"``.
    """
    # Fail fast: an unknown value would otherwise leave ``data`` unbound and
    # surface as a confusing UnboundLocalError further down.
    if which not in ("mono", "merged"):
        raise ValueError(f"which must be 'mono' or 'merged', got {which!r}")

    if which == "mono":
        data = clean_up_columns(pd.read_feather(main_dir / "data/ddr/all_en.feather"))
    else:
        en_data = load_ddr(which="mono", aggregate=False)
        en_data = en_data.rename(
            columns={
                c: c.replace("ddr_all_en", "english_anchor")
                for c in en_data.columns
                if c not in ["index", "id_for_project"]
            }
        )
        multidata = pd.concat(
            [
                clean_up_columns(
                    pd.read_feather(main_dir / f"data/ddr/sentences_{language}.feather")
                )
                for language in languages
            ],
            axis=0,
        )
        multidata = multidata.rename(
            columns={
                c: c.replace("ddr_original_language", "ddr")
                for c in multidata.columns
                if c not in ["index", "id_for_project"]
            }
        )
        # Both inputs are expected to carry an ``index`` column; merge suffixes
        # them ``_x``/``_y`` and neither is needed afterwards.
        data = pd.merge(en_data, multidata, on="id_for_project").drop(
            columns=["index_x", "index_y"]
        )

    if aggregate:
        grouping_helper_df = get_meta_data_file()
        data = data.merge(grouping_helper_df, on="id_for_project", how="left")
        data = data.drop("id_for_project", axis=1)
        data = data.groupby("manifesto_id").mean()
        if explicit_manifesto_id:
            if "index" in data.columns:
                data = data.drop("index", axis=1)
            data = data.rename_axis("manifesto_id").reset_index()
    return data


def load_ccr(
    which: str,
    aggregate: bool = True,
    explicit_manifesto_id: bool = False,
):
    """Load CCR moral-foundation scores from ``data/ccr/ccr_<which>.feather``.

    Spaces are stripped from all column labels before they are normalised
    with ``clean_up_columns``.

    Args:
        which: ``"en_to_en"`` returns that file's scores as they are. Any other
            value (e.g. ``"multi_to_en"``, ``"multi_to_multi"``) inner-joins
            its scores with the ``"en_to_en"`` scores on ``id_for_project``.
            In the English frame, labels containing ``ccr_en_to_en`` are first
            renamed to ``english_anchor``.
        aggregate: If True, attach ``manifesto_id`` from the corpus metadata,
            drop ``id_for_project`` and return per-manifesto means indexed by
            ``manifesto_id``.
        explicit_manifesto_id: Only used when aggregating. If True, drop any
            ``index`` column and return ``manifesto_id`` as a regular column.

    Returns:
        pd.DataFrame
    """
    scores = pd.read_feather(main_dir / f"data/ccr/ccr_{which}.feather")
    scores = clean_up_columns(
        scores.rename(columns={label: label.replace(" ", "") for label in scores.columns})
    )

    if which == "en_to_en":
        data = scores
    else:
        english = load_ccr(which="en_to_en", aggregate=False)
        english = english.rename(
            columns={
                label: label.replace("ccr_en_to_en", "english_anchor")
                for label in english.columns
            }
        )
        joined = pd.merge(english, scores, on="id_for_project")
        data = joined.drop(columns=["index_x", "index_y"])

    if not aggregate:
        return data

    data = data.merge(get_meta_data_file(), on="id_for_project", how="left")
    data = data.drop("id_for_project", axis=1).groupby("manifesto_id").mean()
    if explicit_manifesto_id:
        if "index" in data.columns:
            data = data.drop("index", axis=1)
        data = data.rename_axis("manifesto_id").reset_index()
    return data


def load_mfd(
    which: str,
    aggregate: bool = True,
    how_to_merge: str = "left",
    explicit_manifesto_id: bool = False,
):
    """Load dictionary-based scores (``mfd``, ``mfd2``, ``emfd``).

    Args:
        which: ``"mono"`` returns a dict mapping each method name (``"mfd"``,
            ``"mfd2"``, ``"emfd"``) to its English scores, read from
            ``data/mfd/<method>_scores.feather``; score labels are prefixed
            with ``<method>_``. ``"merged"`` returns one DataFrame joining on
            ``id_for_project``:
            * the English ``mfd`` scores, with labels containing ``mfd``
              renamed to ``english_anchor``;
            * the concatenated ``data/mfd/mfd_<language>_results.feather``
              files, with the ten foundation/pole columns prefixed ``mfd_``.
        aggregate: If True, attach ``manifesto_id`` from the corpus metadata,
            drop ``id_for_project`` and average per ``manifesto_id``.
        how_to_merge: ``"merged"`` only. ``"left"`` keeps every English row;
            any other value performs an inner join.
        explicit_manifesto_id: Only used when aggregating. If True,
            ``manifesto_id`` is returned as a regular column instead of the
            index.

    Returns:
        ``dict[str, pd.DataFrame]`` for ``which="mono"``, ``pd.DataFrame``
        for ``which="merged"``.

    Raises:
        ValueError: If ``which`` is neither ``"mono"`` nor ``"merged"``.
    """
    if which not in ("mono", "merged"):
        raise ValueError(f"which must be 'mono' or 'merged', got {which!r}")

    if which == "mono":
        grouping_helper_df = get_meta_data_file() if aggregate else None
        data = {}
        for method in ["mfd", "mfd2", "emfd"]:
            temp_data = pd.read_feather(main_dir / f"data/mfd/{method}_scores.feather")
            temp_data = temp_data.rename(
                columns={
                    c: f"{method}_{c}"
                    for c in temp_data.columns
                    if c not in ["index", "id_for_project"]
                }
            )
            if aggregate:
                temp_data = temp_data.merge(
                    grouping_helper_df, on="id_for_project", how="left"
                )
                # Averaged sentence ids are meaningless; drop them before
                # grouping, as the other loaders do.
                temp_data = temp_data.drop("id_for_project", axis=1)
                temp_data = temp_data.groupby("manifesto_id").mean()
                if "index" in temp_data.columns:
                    # Previously ``data.drop`` was called here, but ``data``
                    # is the result dict.
                    temp_data = temp_data.drop("index", axis=1)
                # Only meaningful after grouping: without aggregation the
                # index is a plain row number, not a manifesto id.
                if explicit_manifesto_id:
                    temp_data = temp_data.rename_axis("manifesto_id").reset_index()
            data[method] = clean_up_columns(temp_data)
        return data

    mfd_en = load_mfd(which="mono", aggregate=False)["mfd"]
    mfd_en = mfd_en.rename(
        columns={
            c: c.replace("mfd", "english_anchor")
            for c in mfd_en.columns
            if c not in ["index", "id_for_project"]
        }
    )
    # Vertically concatenate all languages, rename the ``docname`` column that
    # comes from quanteda to the project id, and keep only the MFD dimensions.
    multidata = pd.concat(
        [
            clean_up_columns(
                pd.read_feather(main_dir / f"data/mfd/mfd_{language}_results.feather")
            )
            for language in languages
        ],
        axis=0,
    )
    multidata = multidata.rename(columns={"docname": "id_for_project"})
    # .copy() so the column assignment below does not act on a possible view.
    multidata = multidata[[*list_of_all_column_names, "id_for_project"]].copy()
    multidata["id_for_project"] = pd.to_numeric(multidata["id_for_project"])
    multidata = multidata.rename(
        columns={
            c: f"mfd_{c}"
            for c in multidata.columns
            if c not in ["index", "id_for_project"]
        }
    )
    merge_how = "left" if how_to_merge == "left" else "inner"
    data = pd.merge(mfd_en, multidata, on="id_for_project", how=merge_how)

    if aggregate:
        grouping_helper_df = get_meta_data_file()
        data = data.merge(grouping_helper_df, on="id_for_project", how="left")
        data = data.drop("id_for_project", axis=1)
        data = data.groupby("manifesto_id").mean()
        if explicit_manifesto_id:
            if "index" in data.columns:
                data = data.drop("index", axis=1)
            data = data.rename_axis("manifesto_id").reset_index()
    return data


def add_language_iso_column(data: pd.DataFrame, matching_key: str, *, corpus=None):
    """Attach the source corpus' ``language_iso`` column to ``data``.

    Args:
        data: Frame containing a ``matching_key`` column.
        matching_key: Column shared with the source corpus, e.g.
            ``"manifesto_id"`` or ``"id_for_project"``.
        corpus: Optional already-loaded source corpus. If omitted, it is read
            from ``<main_dir>/data/source/manifesto_corpus.feather``. Passing
            it avoids re-reading the file when annotating several frames.

    Returns:
        Inner join of ``data`` with the distinct
        ``(matching_key, language_iso)`` pairs of the corpus. Rows of ``data``
        whose key is absent from the corpus are dropped.
    """
    if corpus is None:
        corpus = pd.read_feather(main_dir / "data/source/manifesto_corpus.feather")
    # The corpus is sentence-level (one row per id_for_project), so keys such
    # as manifesto_id repeat. Without de-duplication, every aggregated row
    # would be copied once per sentence of its manifesto, skewing any
    # statistic computed afterwards.
    lookup = corpus[[matching_key, "language_iso"]].drop_duplicates()
    return pd.merge(data, lookup, on=matching_key)


# Rank-correlation tests supported by save_correlation_tables, keyed by the
# ``method`` name that pandas' ``DataFrame.corr`` uses for the same statistic.
_RANK_TESTS = {"kendall": kendalltau, "spearman": spearmanr}
_STRATEGIES = ("same_category", "same_method", "multi")


def _pvalue_matrix(result: pd.DataFrame, correlation_measure: str) -> pd.DataFrame:
    """Return the pairwise p-values between the columns of ``result``.

    ``DataFrame.corr`` with a callable always puts 1 on the diagonal, so the
    identity is subtracted to make the diagonal read 0 instead.
    """
    rank_test = _RANK_TESTS[correlation_measure]
    pvalues = result.corr(method=lambda x, y: rank_test(x, y)[1])
    return pvalues - np.eye(*pvalues.shape)


def _save_same_category(dict_of_results, correlation_measure):
    """Correlate all methods with each other, separately per foundation/pole."""
    Path(main_dir / "data/for_graphs/same_category/").mkdir(
        parents=True, exist_ok=True
    )
    # idea from here https://stackoverflow.com/questions/34338831/pandas-merge-multiple-dataframes-and-control-column-names
    # ``on`` may name a column or an index level, so this also works when the
    # loaders return manifesto_id as the index.
    merge = functools.partial(pd.merge, on=["manifesto_id"])
    for pole in poles:
        list_of_correlations = []
        for foundation in foundations:
            # One narrow frame per method: the current foundation/pole score
            # plus manifesto_id (if it is a column) for merging.
            narrow_results = [
                df.loc[
                    :,
                    df.columns.str.contains(f"{foundation}.{pole}")
                    | df.columns.isin(["manifesto_id"]),
                ]
                for df in dict_of_results.values()
            ]
            result = functools.reduce(merge, narrow_results)
            # Keep only the method prefix, e.g. "mfd2_care.virtue" -> "mfd2".
            result = result.rename(
                columns={
                    c: re.findall("(.*)_[a-z]*.[a-z]*$", c)[0]
                    for c in result.columns
                }
            )
            correlations = result.corr(method=correlation_measure)
            correlations.reset_index().to_feather(
                main_dir
                / f"data/for_graphs/same_category/correlations_{foundation}_{pole}_{correlation_measure}.feather"
            )
            _pvalue_matrix(result, correlation_measure).reset_index().to_feather(
                main_dir
                / f"data/for_graphs/same_category/correlations_{foundation}_{pole}_pvalues_{correlation_measure}.feather"
            )
            list_of_correlations.append(correlations)

        # Element-wise mean of the per-foundation matrices. They share the same
        # method labels, so the last one supplies index and columns.
        mean_corr_df = pd.DataFrame(
            data=np.mean(list_of_correlations, axis=0),
            index=list_of_correlations[-1].index,
            columns=list_of_correlations[-1].columns,
        )
        mean_corr_df.reset_index().to_feather(
            main_dir
            / f"data/for_graphs/same_category/mean_correlation_{pole}_{correlation_measure}.feather"
        )


def _save_same_method(dict_of_results, normalise, correlation_measure):
    """Correlate the ten foundation/pole scores of each method with each other."""
    Path(main_dir / "data/for_graphs/same_method/").mkdir(parents=True, exist_ok=True)
    score_pattern = "|".join(list_of_all_column_names)
    for method, data in dict_of_results.items():
        result = data[
            data.columns[
                data.columns.str.contains(score_pattern)
                | data.columns.isin(["manifesto_id"])
            ]
        ]
        # Keep only the "<foundation>.<pole>" suffix,
        # e.g. "mfd2_care.virtue" -> "care.virtue".
        result = result.rename(
            columns={
                c: re.findall(".*_([a-z]*.[a-z]*)$", c)[0] for c in result.columns
            }
        )
        result.to_feather(
            main_dir / f"data/for_graphs/same_method/data_{method}.feather"
        )

        if normalise:
            # Min-max scale every column; the unscaled data was saved above.
            scaler = preprocessing.MinMaxScaler()
            result[result.columns] = scaler.fit_transform(result[result.columns])
            result.to_feather(
                main_dir
                / f"data/for_graphs/same_method/data_{method}_normalised.feather"
            )

        if result.size > 0:
            correlations = result.corr(correlation_measure)
            correlations.reset_index().to_feather(
                main_dir
                / f"data/for_graphs/same_method/correlation_{method}_{correlation_measure}.feather"
            )
            # NOTE: the double underscore in this file name is kept unchanged in
            # case downstream code reads the files by this exact name.
            _pvalue_matrix(result, correlation_measure).reset_index().to_feather(
                main_dir
                / f"data/for_graphs/same_method/correlation_{method}_pvalues__{correlation_measure}.feather"
            )


def _save_multi(correlation_measure):
    """Per language, correlate each multilingual method with its English anchor."""
    Path(main_dir / "data/for_graphs/multi/").mkdir(parents=True, exist_ok=True)
    dict_of_results = {
        "ccr_multi_to_en": add_language_iso_column(
            load_ccr(which="multi_to_en", explicit_manifesto_id=True),
            matching_key="manifesto_id",
        ),
        "ccr_multi_to_multi": add_language_iso_column(
            load_ccr(which="multi_to_multi", explicit_manifesto_id=True),
            matching_key="manifesto_id",
        ),
        "ddr": add_language_iso_column(
            load_ddr(which="merged", explicit_manifesto_id=True),
            matching_key="manifesto_id",
        ),
        "mfd": add_language_iso_column(
            load_mfd(which="merged", explicit_manifesto_id=True),
            matching_key="manifesto_id",
        ),
    }
    rank_test = _RANK_TESTS[correlation_measure]
    for method, data in dict_of_results.items():
        correlations_for_df = []
        for language in languages:
            subset = data[data["language_iso"] == language]
            for pole in poles:
                for foundation in foundations:
                    correlation_result = rank_test(
                        subset[f"english_anchor_{foundation}.{pole}"],
                        subset[f"{method}_{foundation}.{pole}"],
                    )
                    correlations_for_df.append(
                        {
                            "language": language,
                            "foundation": foundation,
                            "pole": pole,
                            "correlation": correlation_result.statistic,
                            "pvalue": correlation_result.pvalue,
                        }
                    )

        # Anchored at main_dir (where the directory above was created) rather
        # than relative to the current working directory.
        pd.DataFrame(correlations_for_df).to_feather(
            main_dir
            / f"data/for_graphs/multi/correlations_{method}_{correlation_measure}.feather"
        )


def save_correlation_tables(
    strategy: str, normalise: bool = False, correlation_measure: str = "kendall"
):
    """Compute rank correlations between moral-foundation scores and save them.

    All output is written as feather files below
    ``<main_dir>/data/for_graphs/``.

    Args:
        strategy: Which comparison to run:

            * ``"same_category"``: for each foundation/pole, correlate the
              manifesto-level scores of all methods (DDR, CCR, MoralBERT and
              the dictionary methods) with each other. Also save the mean
              correlation matrix over foundations for each pole.
            * ``"same_method"``: for each method, correlate its foundation/pole
              scores with each other.
            * ``"multi"``: for each multilingual method and language, correlate
              the English-anchor score with the method's own score for every
              foundation/pole.
        normalise: ``"same_method"`` only. Min-max scale the scores before
            correlating, and additionally save the scaled data with a
            ``_normalised`` suffix.
        correlation_measure: ``"kendall"`` or ``"spearman"``.

    Raises:
        ValueError: If ``strategy`` or ``correlation_measure`` is unsupported.
    """
    # Validate before doing any work, so an unsupported value cannot leave
    # partial output behind or fail later with a NameError.
    if strategy not in _STRATEGIES:
        raise ValueError(f"strategy must be one of {_STRATEGIES}, got {strategy!r}")
    if correlation_measure not in _RANK_TESTS:
        raise ValueError(
            f"correlation_measure must be one of {tuple(_RANK_TESTS)}, "
            f"got {correlation_measure!r}"
        )
    Path(main_dir / "data/for_graphs/").mkdir(parents=True, exist_ok=True)

    if strategy == "multi":
        _save_multi(correlation_measure)
        return

    dict_of_results = {
        "ddr": load_ddr(which="mono"),
        "ccr": load_ccr(which="en_to_en"),
        "moralbert": load_moral_bert(),
        **load_mfd(which="mono"),
    }
    if strategy == "same_category":
        _save_same_category(dict_of_results, correlation_measure)
    else:
        _save_same_method(dict_of_results, normalise, correlation_measure)


def score_for_regression():
    """Save the source corpus with an extra ``length_of_statement`` column.

    ``length_of_statement`` holds ``len`` of each ``text_en`` value (the
    character count, for string values). The result is written to
    ``<main_dir>/data/for_graphs/regression/scored_data.feather``.
    """
    corpus = pd.read_feather(main_dir / "data/source/manifesto_corpus.feather")
    corpus["length_of_statement"] = corpus["text_en"].apply(len)

    (main_dir / "data/for_graphs/regression/").mkdir(parents=True, exist_ok=True)
    corpus.to_feather(main_dir / "data/for_graphs/regression/scored_data.feather")


def get_meta_data_file():
    """Return the sentence-to-manifesto mapping from the source corpus.

    Reads ``<main_dir>/data/source/manifesto_corpus.feather`` and returns only
    its ``manifesto_id`` and ``id_for_project`` columns. The loaders merge on
    ``id_for_project`` and then group by ``manifesto_id``.
    """
    wanted = ["manifesto_id", "id_for_project"]
    # Resolve against main_dir like every other path in this module, so the
    # result does not depend on the current working directory. Reading only
    # the two needed columns avoids loading the text columns.
    full_corpus = pd.read_feather(
        main_dir / "data/source/manifesto_corpus.feather", columns=wanted
    )
    return full_corpus[wanted]
